{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Аггрегация разметки датасета RCB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Аггрегация строится по следующей системе:\n",
    "\n",
    "1. Сбор размеченных пулов с Толоки. Возможны варианты:\n",
    "    - только общий пул нужно аггрегировать, тогда забирается только он\n",
    "    - часть данных находится в контрольных заданиях и экзамене, тогда к основному пулу добавляются данные задания\n",
    "2. Фильтрация разметчиков:\n",
    "    - в общем пуле есть некоторое количество заранее размеченных заданий - контрольных\n",
    "    - хорошим считается разметчик, который показывает `accuracy >= 0.5` на данных заданиях\n",
    "    - формируется список \"плохих\" разметчиков\n",
    "3. Аггрегация ответов разметчиков по заданиям:\n",
    "    - форматирование в заданиях может отличаться от изначального из-за выгрузки с Толоки\n",
    "    - учитываются только ответы \"хороших\" разметчиков\n",
    "    - аггрегация по подготовленным пулам - создается массив карточек вида {key: value}, где key - кортеж из всех значимых элементов задания, value - список из кортежей вида (user_id, answer)\n",
    "4. Голосование большинством по каждому заданию:\n",
    "    - минимально необходимое большинство составляет 3 голоса, так как такое большинство валидно для перекрытия 5\n",
    "    - по результату формируется датафрейм с заданиями и ответами\n",
    "5. Подгрузка оригинальных данных с разметкой в виде таблицы с заданиями и ответами\n",
    "6. Соединение таблиц:\n",
    "    - очистка форматирования в таблице с ответами разметчиков и в таблице с правильными ответами\n",
    "    - создание единых столбцов с полным заданием\n",
    "    - соединение таблиц по данному столбцу\n",
    "    - валидация размеров\n",
    "7. Подсчет метрик"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "from sklearn.metrics import f1_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Сбор данных разметки и фильтрация разметчиков"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Датасет для разметки состоит из 525 объектов."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "assignments = pd.read_csv('assignments_from_pool_42365874__11-12-2023.tsv', sep='\\t')\n",
    "skills = pd.read_csv('workerSkills.csv', sep='|')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Разметчикам предлагалось на гипотезы и ситуации определить их взаимосвязь: гипотеза может либо следовать из ситуации, либо противоречить ей, либо не быть с ней связанной.\n",
    "Вход: \n",
    "- INPUT:premise (пример: `Все утро шел сильный дождь.`).\n",
    "- INPUT:hypothesis (пример: `Я весь вымок.`).\n",
    "\n",
    "Выход:\n",
    "- OUTPUT:answer (одна из трех цифр: `1`, `2`, `3`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>INPUT:premise</th>\n",
       "      <th>INPUT:hypothesis</th>\n",
       "      <th>OUTPUT:answer</th>\n",
       "      <th>GOLDEN:answer</th>\n",
       "      <th>HINT:text</th>\n",
       "      <th>HINT:default_language</th>\n",
       "      <th>ASSIGNMENT:link</th>\n",
       "      <th>ASSIGNMENT:task_id</th>\n",
       "      <th>ASSIGNMENT:assignment_id</th>\n",
       "      <th>ASSIGNMENT:task_suite_id</th>\n",
       "      <th>ASSIGNMENT:worker_id</th>\n",
       "      <th>ASSIGNMENT:status</th>\n",
       "      <th>ASSIGNMENT:started</th>\n",
       "      <th>ASSIGNMENT:submitted</th>\n",
       "      <th>ASSIGNMENT:accepted</th>\n",
       "      <th>ASSIGNMENT:reward</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>В ближайшие несколько месяцев будут представле...</td>\n",
       "      <td>Сразу несколько аппаратов станут новым витком ...</td>\n",
       "      <td>1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>https://platform.toloka.ai/task/42365874/00028...</td>\n",
       "      <td>00028673b2--65661c962eae777d413ceec5</td>\n",
       "      <td>00028673b2--656626d8cb94b61ac9a31ba0</td>\n",
       "      <td>00028673b2--656626d8cb94b61ac9a31b9b</td>\n",
       "      <td>613cd9c009dce0c9dbdfed7a4fc2aa64</td>\n",
       "      <td>APPROVED</td>\n",
       "      <td>2023-11-28T17:43:52.392</td>\n",
       "      <td>2023-11-28T17:45:03.448</td>\n",
       "      <td>2023-11-28T17:45:03.448</td>\n",
       "      <td>0.03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       INPUT:premise  \\\n",
       "0  В ближайшие несколько месяцев будут представле...   \n",
       "\n",
       "                                    INPUT:hypothesis  OUTPUT:answer  \\\n",
       "0  Сразу несколько аппаратов станут новым витком ...              1   \n",
       "\n",
       "   GOLDEN:answer  HINT:text  HINT:default_language  \\\n",
       "0            NaN        NaN                    NaN   \n",
       "\n",
       "                                     ASSIGNMENT:link  \\\n",
       "0  https://platform.toloka.ai/task/42365874/00028...   \n",
       "\n",
       "                     ASSIGNMENT:task_id              ASSIGNMENT:assignment_id  \\\n",
       "0  00028673b2--65661c962eae777d413ceec5  00028673b2--656626d8cb94b61ac9a31ba0   \n",
       "\n",
       "               ASSIGNMENT:task_suite_id              ASSIGNMENT:worker_id  \\\n",
       "0  00028673b2--656626d8cb94b61ac9a31b9b  613cd9c009dce0c9dbdfed7a4fc2aa64   \n",
       "\n",
       "  ASSIGNMENT:status       ASSIGNMENT:started     ASSIGNMENT:submitted  \\\n",
       "0          APPROVED  2023-11-28T17:43:52.392  2023-11-28T17:45:03.448   \n",
       "\n",
       "       ASSIGNMENT:accepted  ASSIGNMENT:reward  \n",
       "0  2023-11-28T17:45:03.448               0.03  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "assignments.head(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Фильтруем толокеров с `accuracy < 0.5` на контрольных заданиях, чтобы не учитывать их ответы при подсчете метрик."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Users total:  229\n",
      "Bad users: 112\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "users_dict = defaultdict(lambda: defaultdict(int))\n",
    "\n",
    "for idx, row in assignments.iterrows():\n",
    "    text = row[0]\n",
    "\n",
    "    out = row[2]\n",
    "    \n",
    "    gold = row[3]\n",
    "\n",
    "    user = row[10]\n",
    "\n",
    "    if str(user) != \"nan\" and str(gold) != \"nan\":\n",
    "        if int(out) == int(gold):\n",
    "            users_dict[user][\"good\"] += 1\n",
    "        else:\n",
    "            users_dict[user][\"bad\"] += 1\n",
    "\n",
    "print(\"Users total: \", len(users_dict))\n",
    "bad_users = []\n",
    "for key, value in users_dict.items():\n",
    "    percentage_good = value[\"good\"]/(value[\"good\"] + value[\"bad\"])\n",
    "    if percentage_good < 0.5:\n",
    "        bad_users.append(key)\n",
    "\n",
    "print(\"Bad users:\", len(bad_users))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "112 из 229 разметчиков на контрольных заданиях показали слишком плохое качество, чтобы учитывать их ответы для расчета метрики."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Отделяем контроль от основы, так как контрольные задания создавались отдельно и не должны учитываться при подсчете метрик. На контрольных заданиях есть `GOLDEN:answer`. Также отсеиваем возможные баги Толоки, когда в строке может не быть задания - `INPUT:hypothesis` содержит NaN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "assignments_no_control = assignments[assignments['GOLDEN:answer'].isnull()]\n",
    "assignments_no_control_no_null = assignments_no_control[assignments_no_control['INPUT:hypothesis'].notnull()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Посчитаем, сколько было затрачено на получение разметки тестовых данных без учета контрольных заданий, так как они могли проходиться неограниченное количество раз одним и тем же разметчиком."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "взвешенная цена айтема в тесте: 0.034\n",
      "потрачено на разметку теста: 73.464\n",
      "73.464 / 0.034\n"
     ]
    }
   ],
   "source": [
    "def w_sum(df):\n",
    "    idx = df.index.values\n",
    "    vals = df.values\n",
    "    summ = idx * vals\n",
    "    return summ.sum()\n",
    "\n",
    "d1 = assignments_no_control_no_null['ASSIGNMENT:reward'].value_counts(normalize=True)\n",
    "d2 = assignments_no_control_no_null['ASSIGNMENT:reward'].value_counts()\n",
    "print(f'взвешенная цена айтема в тесте: {round(w_sum(d1), 3)}')\n",
    "print(f'потрачено на разметку теста: {round(w_sum(d2), 3)}')\n",
    "print(f'{round(w_sum(d2), 3)} / {round(w_sum(d1), 3)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Выделим, сколько составила средняя часовая ставка для разметки тестовой части датасета. Это будет простое среднее из следующих величин: количество заданий, которое разметчк может сделать за час на основе данного задания, помноженное на цену задания."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.6095295111869774"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_hour_pay(df):\n",
    "    try:\n",
    "        times = pd.to_datetime(df['ASSIGNMENT:submitted']) - pd.to_datetime(df['ASSIGNMENT:started'])\n",
    "    except Exception as e:\n",
    "        times = []\n",
    "        for i in range(len(assignments_no_control_no_null)):\n",
    "            try:\n",
    "                start = pd.to_datetime(assignments_no_control_no_null['ASSIGNMENT:started'].iloc[i])\n",
    "            except Exception as e:\n",
    "                start = pd.to_datetime(assignments_no_control_no_null['ASSIGNMENT:started'].apply(lambda x: x.split('T')[1]).iloc[i])\n",
    "            try:\n",
    "                end = pd.to_datetime(assignments_no_control_no_null['ASSIGNMENT:submitted'].iloc[i])\n",
    "            except Exception as e:\n",
    "                start = pd.to_datetime(assignments_no_control_no_null['ASSIGNMENT:submitted'].apply(lambda x: x.split('T')[1]).iloc[i])\n",
    "            delta = end - start\n",
    "            times.extend([delta])\n",
    "        times = pd.Series(times)\n",
    "        # times = pd.to_datetime(df['ASSIGNMENT:submitted'].apply(lambda x: x.split('T')[1])) - pd.to_datetime(df['ASSIGNMENT:started'].apply(lambda x: x.split('T')[1]))\n",
    "    sums = 3600 / times.apply(lambda x: x.seconds) * df['ASSIGNMENT:reward']\n",
    "    return sums.mean()\n",
    "\n",
    "get_hour_pay(assignments_no_control_no_null)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Сбор ответов разметчиков и голосование"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Собираем ответы голосования большинством для каждого задания."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "438\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "text_dict = defaultdict(list)\n",
    "\n",
    "for premise, hypothesis, user, out in zip(\n",
    "    assignments_no_control_no_null[\"INPUT:premise\"], assignments_no_control_no_null[\"INPUT:hypothesis\"],\n",
    "    assignments_no_control_no_null[\"ASSIGNMENT:worker_id\"], assignments_no_control_no_null[\"OUTPUT:answer\"]\n",
    "    ):\n",
    "    if user not in bad_users:\n",
    "        text_dict[(premise, hypothesis)].append([\n",
    "                user,\n",
    "                {\"out\": out}\n",
    "        ])\n",
    "\n",
    "print(len(text_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({5: 172, 4: 168, 3: 76, 2: 18, 1: 4})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "keys = list(text_dict.keys())\n",
    "Counter([len(text_dict[keys[i]]) for i in range(len(keys))])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Есть 110 заданий, где перекрытие меньше 5. Для формирования итоговых лейблов нужно, чтобы было простое большинство разметчиков, проголосовавших за данную опцию. Если большинства нет, то оценка строится, исходя из оценки навыков разметчиков. В таком случае, финальный лейбл будет присвоен по голосу группы с наилучшими навыками."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_full = {}\n",
    "user2skill = {k:v for k, v in zip(skills['worker_id'], skills['skill_value'])}\n",
    "control_acc = assignments[assignments['GOLDEN:answer'].notna()]\\\n",
    "    .groupby('ASSIGNMENT:worker_id')\\\n",
    "        .apply(lambda x: (np.array(x['OUTPUT:answer']) == np.array(x['GOLDEN:answer'])).mean())\n",
    "user2control = {k:v for k, v in zip(control_acc.index, control_acc.values)}\n",
    "\n",
    "from crowdkit.aggregation.classification.glad import GLAD\n",
    "\n",
    "full = assignments['INPUT:premise'] + ' ' + assignments['INPUT:hypothesis']\n",
    "id2task = dict(enumerate(full))\n",
    "task2id = {k:v for v, k in id2task.items()}\n",
    "id2user = dict(enumerate(assignments['ASSIGNMENT:worker_id']))\n",
    "user2id = {k:v for v, k in id2user.items()}\n",
    "\n",
    "codes = full.map(task2id)\n",
    "res = pd.DataFrame({'task': codes, 'worker': assignments['ASSIGNMENT:worker_id'].map(user2id), 'label': assignments['OUTPUT:answer']})\n",
    "model = GLAD(n_iter=10000, tol=1e-06, m_step_max_iter=1000, m_step_tol=1e-03)\n",
    "model.fit(res)\n",
    "user2alpha = dict(enumerate(model.alphas_))\n",
    "tb = model.alphas_.copy()\n",
    "tb.index = tb.index.map(id2user)\n",
    "user2alpha = {k:v for k, v in zip(tb.index, tb.values)}\n",
    "\n",
    "stats = {\n",
    "    'total_agreement': 0,\n",
    "    'majority': 0,\n",
    "    'skill_based': 0,\n",
    "    'major_based': 0,\n",
    "    'em_based': 0,\n",
    "    'rest': 0,\n",
    "}\n",
    "\n",
    "for i in range(len(keys)):\n",
    "    ans = text_dict[keys[i]]\n",
    "    lst = [[ans[j][0], ans[j][1]['out']] for j in range(len(ans))]\n",
    "    users, votes = list(zip(*lst))\n",
    "    cnt = pd.Series(Counter(votes)).sort_values(ascending=False)\n",
    "\n",
    "    # # total agreement\n",
    "    if len(cnt) == 1:\n",
    "        res = cnt.index[0]\n",
    "        stats['total_agreement'] += 1\n",
    "    # simple majority\n",
    "    elif cnt.iloc[0] > cnt.iloc[1]:\n",
    "        res = cnt.index[0]\n",
    "        stats['majority'] += 1\n",
    "    # (> 1 options) & (1 option == 2 option)\n",
    "    else:\n",
    "        # try overall skill based comparison\n",
    "        vals = list(map(lambda x: user2skill[x], users))\n",
    "        table = pd.DataFrame({'user': users, 'votes': votes, 'skill': vals})\n",
    "        agg = table.groupby('votes').agg(\n",
    "            sum_skill=pd.NamedAgg(column='skill', aggfunc='sum'),\n",
    "            sum_votes=pd.NamedAgg(column='user', aggfunc='count')\n",
    "        ).sort_values(by=['sum_votes', 'sum_skill'], ascending=False)\n",
    "        # check there is a leader by skills\n",
    "        if agg['sum_skill'].iloc[0] > agg['sum_skill'].iloc[1]:\n",
    "            res = agg.index[0]\n",
    "            stats['skill_based'] += 1\n",
    "        else:\n",
    "            # top-3 answers by overall skills\n",
    "            vals = list(map(lambda x: user2skill[x], users))\n",
    "            table = pd.DataFrame({'user': users, 'votes': votes, 'skill': vals})\n",
    "            table = table.sort_values(by='skill', ascending=False)\n",
    "            if len(table) >= 3:\n",
    "                sub = table.iloc[:3]\n",
    "            else:\n",
    "                sub = table\n",
    "            agg = sub.groupby('votes').agg(\n",
    "                sum_skill=pd.NamedAgg(column='skill', aggfunc='sum'),\n",
    "                sum_votes=pd.NamedAgg(column='user', aggfunc='count')\n",
    "            ).sort_values(by=['sum_votes', 'sum_skill'], ascending=False)\n",
    "            if agg['sum_skill'].iloc[0] != agg['sum_skill'].iloc[1]:\n",
    "                res = agg.index[0]\n",
    "                stats['major_based'] += 1\n",
    "            \n",
    "            else:\n",
    "                vals = list(map(lambda x: user2alpha[x], users))\n",
    "                table = pd.DataFrame({'user': users, 'votes': votes, 'skill': vals})\n",
    "                agg = table.groupby('votes').agg(\n",
    "                    sum_skill=pd.NamedAgg(column='skill', aggfunc='sum'),\n",
    "                    sum_votes=pd.NamedAgg(column='user', aggfunc='count')\n",
    "                ).sort_values(by=['sum_votes', 'sum_skill'], ascending=False)\n",
    "                # check there is a leader by skills\n",
    "                if agg['sum_skill'].iloc[0] != agg['sum_skill'].iloc[1]:\n",
    "                    res = agg.index[0]\n",
    "                    stats['em_based'] += 1\n",
    "                else:\n",
    "                    res = agg.index[0]\n",
    "                    stats['rest'] += 1\n",
    "\n",
    "    preds_full[keys[i]] = res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'total_agreement': 218,\n",
       " 'majority': 190,\n",
       " 'skill_based': 3,\n",
       " 'major_based': 16,\n",
       " 'em_based': 11,\n",
       " 'rest': 0}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_full_df = pd.concat([pd.DataFrame(preds_full.keys(), columns=['premise', 'hypothesis']), pd.DataFrame(preds_full.values(), columns=['lb'])], axis=1).astype(str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Сопоставление разметки и ground truth"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Забираем задания из датасета с правильными ответами."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df = pd.read_csv('general_wa.tsv', sep='\\t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df = res_df.rename({\n",
    "    'INPUT:premise': 'premise',\n",
    "    'INPUT:hypothesis': 'hypothesis',\n",
    "    'GOLDEN:answer': 'lb',\n",
    "}, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "После скачивания с Толоки в текстах рушится форматирование, потому нельзя просто сделать join двух табличек. Нужно убрать все \"лишнее\" форматирование сразу из двух табличек, чтобы остались только тексты, пунктуация и пробелы."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>lb</th>\n",
       "      <th>premise</th>\n",
       "      <th>hypothesis</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3</td>\n",
       "      <td>Мужчина уже был ранее судим, рассказали «Комсо...</td>\n",
       "      <td>Мужчину раньше уже судили.</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   lb                                            premise  \\\n",
       "0   3  Мужчина уже был ранее судим, рассказали «Комсо...   \n",
       "\n",
       "                   hypothesis  \n",
       "0  Мужчину раньше уже судили.  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_text(text):\n",
    "    text = (text.strip().replace('\\n', ' ').replace('\\t', ' ')\n",
    "            .replace('\\r', ' ').replace('  ', ' ').replace('  ', ' ')\n",
    "            .replace('  ', ' '))\n",
    "    return text\n",
    "\n",
    "res_df['premise'] = res_df['premise'].apply(format_text)\n",
    "res_df['hypothesis'] = res_df['hypothesis'].apply(format_text)\n",
    "\n",
    "preds_full_df['premise'] = preds_full_df['premise'].apply(format_text)\n",
    "preds_full_df['hypothesis'] = preds_full_df['hypothesis'].apply(format_text)\n",
    "\n",
    "res_df['full'] = res_df['premise'] + ' ' + res_df['hypothesis']\n",
    "preds_full_df['full'] = preds_full_df['premise'] + ' ' + preds_full_df['hypothesis']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Делаем left join, чтобы соединить голосование и правильные метки для одних и тех же заданий."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "new = res_df.merge(preds_full_df.drop(['premise', 'hypothesis'], axis=1), on='full', how='left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "438"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_valid = new[new['lb_y'].notna()].copy()\n",
    "len(new_valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_valid['lb_y'] = new_valid['lb_y'].astype(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ни одна строка не была утеряна."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>lb_x</th>\n",
       "      <th>premise</th>\n",
       "      <th>hypothesis</th>\n",
       "      <th>full</th>\n",
       "      <th>lb_y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3</td>\n",
       "      <td>Мужчина уже был ранее судим, рассказали «Комсо...</td>\n",
       "      <td>Мужчину раньше уже судили.</td>\n",
       "      <td>Мужчина уже был ранее судим, рассказали «Комсо...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   lb_x                                            premise  \\\n",
       "0     3  Мужчина уже был ранее судим, рассказали «Комсо...   \n",
       "\n",
       "                   hypothesis  \\\n",
       "0  Мужчину раньше уже судили.   \n",
       "\n",
       "                                                full  lb_y  \n",
       "0  Мужчина уже был ранее судим, рассказали «Комсо...     1  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_valid.head(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Подсчет метрики"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Если в правом столбце меток осталось 438 непустых строк, значит, форматирование было подчищено корректно и ничего не потерялось"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Попробуем посчитать разные метрики"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.58675799086758"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(new_valid['lb_x'] == new_valid['lb_y']).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5651979834870691"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f1_score(new_valid['lb_x'], new_valid['lb_y'], average='macro')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Accuracy = 0.587`, `F1_macro = 0.565`"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
